from functools import wraps

from sentry_sdk.integrations import DidNotEnable
import sentry_sdk

from ..spans import execute_tool_span, update_execute_tool_span
from ..utils import (
    _capture_exception,
    get_current_agent,
)

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any

# MCP support is optional: record whether MCPServer could be imported so that
# code referencing MCPServer can check HAS_MCP first.
try:
    from pydantic_ai.mcp import MCPServer  # type: ignore

    HAS_MCP = True
except ImportError:
    HAS_MCP = False

# ToolManager is required by this module. If it cannot be imported, raise
# DidNotEnable instead of letting the raw ImportError propagate.
try:
    from pydantic_ai._tool_manager import ToolManager  # type: ignore
except ImportError:
    raise DidNotEnable("pydantic-ai not installed")


def _patch_tool_execution():
    # type: () -> None
    """
    Patch ToolManager._call_tool to create execute_tool spans.

    This is the single point where ALL tool calls flow through in pydantic_ai,
    regardless of toolset type (function, MCP, combined, wrapper, etc.).

    By patching here, we avoid:
    - Patching multiple toolset classes
    - Dealing with signature mismatches from instrumented MCP servers
    - Complex nested toolset handling

    The wrapper only creates a span when both a current agent (from
    ``get_current_agent()``) and a registered tool matching
    ``call.tool_name`` are available; otherwise it delegates straight to the
    original method. Exceptions raised by the original method are reported
    via ``_capture_exception`` and then re-raised unchanged.
    """

    original_call_tool = ToolManager._call_tool

    @wraps(original_call_tool)
    async def wrapped_call_tool(self, call, *args, **kwargs):
        # type: (Any, Any, *Any, **Any) -> Any

        # Extract tool info before calling original
        name = call.tool_name
        tool = self.tools.get(name) if self.tools else None

        # Determine tool type by checking tool.toolset.
        # HAS_MCP is checked first so MCPServer is only referenced when its
        # import succeeded.
        tool_type = "function"  # default
        if tool and HAS_MCP and isinstance(tool.toolset, MCPServer):
            tool_type = "mcp"

        # Get agent from contextvar
        agent = get_current_agent()

        if not (agent and tool):
            # No span context - just call original
            return await original_call_tool(self, call, *args, **kwargs)

        # Best effort: span arguments are informational only, so a failure to
        # build them must never prevent the tool from running.
        try:
            args_dict = call.args_as_dict()
        except Exception:
            args_dict = call.args if isinstance(call.args, dict) else {}

        # Create execute_tool span
        # Nesting is handled by isolation_scope() to ensure proper parent-child relationships
        with sentry_sdk.isolation_scope(), execute_tool_span(
            name,
            args_dict,
            agent,
            tool_type=tool_type,
        ) as span:
            try:
                result = await original_call_tool(self, call, *args, **kwargs)
                update_execute_tool_span(span, result)
                return result
            except Exception as exc:
                _capture_exception(exc)
                # Bare ``raise`` re-raises the same exception object without
                # touching its chaining. ``raise exc from None`` would reset
                # ``__cause__`` and set ``__suppress_context__``, hiding any
                # cause/context the tool itself attached to the error.
                raise

    ToolManager._call_tool = wrapped_call_tool
